Debugging performance discrepancy between PyTorch and JAX variants of NVDiffrast #21
horizon-blue wants to merge 2 commits into main
Conversation
  const int32_t* triPtr = tri;
  int vtxPerInstance = d.num_vertices;
- rasterizeRender(NVDR_CTX_PARAMS, s, stream, posPtr, posCount, vtxPerInstance, triPtr, triCount, rangesPtr, width, height, depth, peeling_idx);
+ rasterizeRender(NVDR_CTX_PARAMS, s, stream, posPtr, posCount, vtxPerInstance, triPtr, triCount, ranges, width, height, depth, peeling_idx);
I think this change will actually modify the behavior, right?
For our benchmark, passing in 0 or ranges doesn't seem to make any difference to the output or the compute time (because we're just rendering a single scene).
I thought it was a mistake not to pass in the actual range (since we are using it in the current b3d renderer), but I can revert this change if this is intentional :)
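For context, upstream nvdiffrast's PyTorch front end documents `ranges` as being consulted only in range mode and ignored in instanced mode (batched `pos`). Below is a hedged sketch of the two calling modes, assuming (not verified) that the b3d C++ wrapper follows the same convention:

```python
# Hedged illustration using upstream nvdiffrast's PyTorch front end;
# shapes follow the upstream documentation.
import torch
import nvdiffrast.torch as dr

glctx = dr.RasterizeGLContext()
tri = torch.tensor([[0, 1, 2]], dtype=torch.int32, device="cuda")

# Instanced mode: pos is [minibatch, num_vertices, 4]; `ranges` is ignored.
pos_instanced = torch.tensor(
    [[[-0.5, -0.5, 0.0, 1.0], [0.5, -0.5, 0.0, 1.0], [0.0, 0.5, 0.0, 1.0]]],
    device="cuda")
rast_a, _ = dr.rasterize(glctx, pos_instanced, tri, resolution=[256, 256])

# Range mode: pos is [num_vertices, 4]; `ranges` ([minibatch, 2], int32, on
# the CPU) says which triangles belong to each output image.
pos_flat = pos_instanced[0]
ranges = torch.tensor([[0, 1]], dtype=torch.int32)  # start index, triangle count
rast_b, _ = dr.rasterize(glctx, pos_flat, tri, resolution=[256, 256], ranges=ranges)
```

If the benchmark goes through the instanced path with a single scene, that would be consistent with 0 and the actual ranges buffer producing identical results.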
Okay, I just reverted the change to …

Another thing that's worth mentioning is that, even though we were rendering 1,000 times in our benchmark, the entire loop can still be executed in less than 0.5 s, so the overhead from XLA can still dominate. To see when JAX (and XLA) starts to shine, we can make the compute graph bigger, e.g. by rendering 50,000 times instead. JAX's performance numbers start to catch up here, because …

(Though it seems like PyTorch is also introducing their own …)
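A minimal sketch of this kind of FPS measurement, using a jitted placeholder `render_once` instead of the real rasterizer (names are illustrative, not taken from `test_renderer_fps.py`):

```python
# Hedged timing sketch: `render_once` stands in for one rasterization call.
import time
import jax
import jax.numpy as jnp

@jax.jit
def render_once(x):
    return jnp.sin(x).sum()  # placeholder compute, not the real renderer

def benchmark(n_iters):
    x = jnp.ones((1024, 1024))
    render_once(x).block_until_ready()   # compile / warm up outside the timer
    start = time.perf_counter()
    out = None
    for _ in range(n_iters):
        out = render_once(x)
    out.block_until_ready()              # wait for async work before stopping the clock
    elapsed = time.perf_counter() - start
    print(f"{n_iters} renders: {elapsed:.3f}s ({n_iters / elapsed:.1f} FPS)")

benchmark(1_000)    # cheap total workload: fixed per-call overhead dominates
benchmark(50_000)   # bigger workload: GPU compute is a larger share of the total
```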
The following notes are modified from the related Notion card.

Benchmarking script: `b3d/test/test_renderer_fps.py`
**Before**

Output of `python test/test_renderer_fps.py`:

**First change: Using `lax.scan` instead of a for loop**
This should let us get rid of some overhead from XLA…
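A minimal sketch of the idea, with a placeholder `render` function standing in for the rasterization call (names are illustrative, not the actual benchmark code):

```python
import jax
import jax.numpy as jnp

def render(pose):
    return jnp.sin(pose).sum()  # placeholder for one rasterization step

def render_all_python_loop(poses):
    # Plain Python loop: under jit this unrolls into a graph that grows with
    # the number of renders; without jit each step pays Python dispatch overhead.
    return jnp.stack([render(p) for p in poses])

@jax.jit
def render_all_scan(poses):
    # lax.scan traces the body once and iterates it, keeping the graph small.
    def body(carry, pose):
        return carry, render(pose)
    _, outs = jax.lax.scan(body, None, poses)
    return outs

poses = jnp.ones((1000, 6))
assert render_all_scan(poses).shape == (1000,)
```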
(Note: `lax.while_loop` should achieve a similar effect.)

Related: `jax.lax.scan` and `jax.lax.while_loop`

**Second change: Removing unnecessary `cudaStreamSynchronize(stream)`**
Disclaimer: I'm not certain about this change, since I'm new to CUDA programming.
It looks like we're calling `cudaStreamSynchronize(stream)` a lot in the definition of the JAX rasterize wrapper code (e.g. `jax_rasterize_gl.cpp`). However, except for debugging, we probably don't want to block the CPU until the stream has finished execution.
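The pattern in question looks roughly like the following. This is a hedged sketch of a generic JAX CUDA custom-call wrapper, not the actual contents of `jax_rasterize_gl.cpp`, and the function name is illustrative:

```cpp
#include <cuda_runtime.h>

// Illustrative custom-call wrapper: XLA hands us the stream on which it
// schedules this op and the ops that consume its outputs.
void rasterize_custom_call(cudaStream_t stream, void** buffers,
                           const char* opaque, size_t opaque_len) {
  // ... unpack inputs/outputs from `buffers`, launch rasterization kernels on `stream` ...

  // Device-to-device copies of results can also stay asynchronous:
  // cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, stream);

  // A trailing host-side synchronization like this forces the CPU to wait for
  // the whole stream to drain on every call. Since downstream ops that consume
  // the outputs run on the same stream, it is mainly useful for debugging
  // (surfacing CUDA errors at the call site) and otherwise just adds latency:
  // cudaStreamSynchronize(stream);   // <-- the kind of call being removed
}
```

Whether every such call is actually safe to drop depends on what the surrounding CUDA/OpenGL interop code assumes, which lines up with the nondeterministic errors mentioned in Note 1 below.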
By deleting `cudaStreamSynchronize(stream)` from the C++ implementations, we can see another performance bump on the JAX rasterizer:

Note 1: It seems like removing all stream synchronization from the b3d version of the renderer (aka JAX) can result in nondeterministic CUDA errors. I haven't taken a super close look into the b3d version of the renderer to find out which calls are removable, so those numbers are not included above. Though even when it doesn't error out, we don't see the same performance boost:

Note 2: After removing the unnecessary `cudaStreamSynchronize` call, the output of JAX NVDiffrast is still the same as the PyTorch version :)

I'm just pushing the code here so people can give it a try. Even though we're only tweaking the rasterization operator here, this can give us some ideas about how to improve the performance of the overall rendering pipeline. @nishadgothoskar